{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "\n",
    "import time\n",
    "import tqdm\n",
    "import itertools\n",
    "\n",
    "# uses allennlp modules\n",
    "from allennlp.nn import util\n",
    "from allennlp.nn.beam_search import BeamSearch\n",
    "\n",
    "# imports chinese gpt\n",
    "from chinese_gpt import TransformerDecoderLM\n",
    "\n",
    "# uses bert chinese wordpiece tokenization\n",
    "from pytorch_pretrained_bert import BertTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def top_k_logits(logits, k):\n",
    "    \"\"\"Mask logits so that only top-k logits remain\n",
    "    \"\"\"\n",
    "    values, _ = torch.topk(logits, k)\n",
    "    min_values = values[:, -1].unsqueeze(1).repeat(1, logits.shape[-1])\n",
    "    return torch.where(logits < min_values, torch.ones_like(logits, dtype=logits.dtype) * -1e10, logits)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define Bert tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = BertTokenizer.from_pretrained(\"bert-base-chinese\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tokenize\n",
      "['今', '天', '北', '京', '天', '气', '出', '现', '小', '雨', ',', '山', '区', '还', '出', '现', '了', '降', '雪', ',', '气', '温', '下', '降', ',', '体', '感', '十', '分', '寒', '冷', '。']\n",
      "Tokens to ids\n",
      "[791, 1921, 1266, 776, 1921, 3698, 1139, 4385, 2207, 7433, 117, 2255, 1277, 6820, 1139, 4385, 749, 7360, 7434, 117, 3698, 3946, 678, 7360, 117, 860, 2697, 1282, 1146, 2170, 1107, 511]\n",
      "Ids to tokens\n",
      "['今', '天', '北', '京', '天', '气', '出', '现', '小', '雨', ',', '山', '区', '还', '出', '现', '了', '降', '雪', ',', '气', '温', '下', '降', ',', '体', '感', '十', '分', '寒', '冷', '。']\n"
     ]
    }
   ],
   "source": [
    "sentence = \"今天北京天气出现小雨,山区还出现了降雪,气温下降,体感十分寒冷。\"\n",
    "print(\"Tokenize\")\n",
    "print(tokenizer.tokenize(sentence))\n",
    "print(\"Tokens to ids\")\n",
    "ids = tokenizer.convert_tokens_to_ids(sentence)\n",
    "print(ids)\n",
    "print(\"Ids to tokens\")\n",
    "print(tokenizer.convert_ids_to_tokens(ids))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 1\n",
    "# make sure your model is on GPU\n",
    "device = torch.device(\"cuda\")\n",
    "\n",
    "model = TransformerDecoderLM()\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load weights into the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "old_state_dict = torch.load(\"model_state_epoch_62.th\", map_location=lambda storage, loc: storage)\n",
    "new_state_dict = model.state_dict()\n",
    "\n",
    "for item in new_state_dict.keys():\n",
    "    new_state_dict[item] = old_state_dict['module.'+item]\n",
    "    \n",
    "model.load_state_dict(new_state_dict)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Conditioanl or Unconditional Decoding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ask more about news\n",
    "prompt = tokenizer.tokenize(\"房价上涨的原因主要有哪些?\")\n",
    "prompt = tokenizer.convert_tokens_to_ids(prompt)\n",
    "#prompt = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'一般来说，房价上涨主要是由于土地出让价不高,房产价格高低的差异,房价与土地价格之差一般也较小。这里是从房地产价格的涨跌变化入手,房地产上涨与土地价格的一般变化主要通过房地产价格的变动来区分。由于地区房价的上涨与房地产价格的上涨,房产价格上涨的房价与房价之差也一定大,但是房价上涨与房地产价格的不变的差别可不是单个数量差异。一是从目前房价的变幅上来看,房价与房地产价格总变动,房价比房地产发展速度快,房地产价格比房地产发展速度慢;二是房地产市场的持续供应不足、供应过剩,房地产价格与房地产价格的变动,需要进一步深入研究房企的资产整体价格变化,并进行分析。一般来说，房价比房地产价格高,房地产价格也比房地产价格高。但是,房价的变化是可以间接计算的,所以,房地产价格与房地产价格也是可以用现成价值表示的,并且有一定的随机性。因为,房地产价格的变化并不需要房地产企业经过多少年的沉淀,房地产企业也不可能长久地维持这种地价的变动,只有这样,房价才有可能成为房价的基础。房地产行情是房国资企业经营的重要因素,对房企来说,只要有房地产企业在中国,房地产业的生存方面有了较好的基础,那么,房地产商对房屋的需求就会相对容'"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "top_k = 20\n",
    "temperature = 1.0\n",
    "length = 0\n",
    "\n",
    "start_predictions = torch.LongTensor([[101] + prompt]* batch_size).to(device)\n",
    "mask = torch.ones(batch_size, start_predictions.shape[1]).to(device)\n",
    "\n",
    "with torch.no_grad():\n",
    "    # cache saves in past\n",
    "    logits, past = model(start_predictions, mask, past=None, past_length=0)\n",
    "    logits = logits[:, -1, :] / temperature\n",
    "    logits = top_k_logits(logits, k=top_k)\n",
    "\n",
    "    sentence = []\n",
    "\n",
    "    probs = F.softmax(logits, dim=-1)\n",
    "    prob, prev_pred = torch.topk(probs, k=1, dim=-1)\n",
    "    sentence.append(prev_pred)\n",
    "    length += 1\n",
    "\n",
    "    # decoding loop\n",
    "    for i in range(500):\n",
    "        mask = F.pad(mask, (0, 1), \"constant\", 1.0)\n",
    "        logits, past = model(prev_pred, mask, past=past, past_length=length)\n",
    "        logits = logits.squeeze(1) / temperature\n",
    "        logits = top_k_logits(logits, k=top_k)\n",
    "        probs = F.softmax(logits, dim=-1)\n",
    "        prev_pred = torch.multinomial(probs, num_samples=1)\n",
    "        sentence.append(prev_pred)\n",
    "        length += 1\n",
    "\n",
    "    sentence = torch.cat(sentence, dim=-1)\n",
    "\n",
    "\"\".join(tokenizer.convert_ids_to_tokens(sentence[0].tolist()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
